{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Regression with Images and Text\n",
    "\n",
    "In this notebook we will go through a series of examples on how to combine all Wide & Deep components.\n",
    "\n",
    "To that aim I will use the Airbnb listings dataset for London, which you can download from [here](http://insideairbnb.com/get-the-data.html). I use this dataset simply because it contains tabular data, images and text.\n",
    "\n",
    "I have taken a sample of 1000 listings to keep the data tractable in this notebook. Also, I have preprocessed the data and prepared it for this exercise. All preprocessing steps can be found in the notebook `airbnb_data_preprocessing.ipynb` in this `examples` folder. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/javierrodriguezzaurin/.pyenv/versions/3.10.13/envs/widedeep310/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import torch\n",
    "from torchvision.transforms import ToTensor, Normalize\n",
    "\n",
    "from pytorch_widedeep import Trainer\n",
    "from pytorch_widedeep.preprocessing import (\n",
    "    WidePreprocessor,\n",
    "    TabPreprocessor,\n",
    "    TextPreprocessor,\n",
    "    ImagePreprocessor,\n",
    ")\n",
    "from pytorch_widedeep.models import (\n",
    "    Wide,\n",
    "    TabMlp,\n",
    "    Vision,\n",
    "    BasicRNN,\n",
    "    WideDeep,\n",
    ")\n",
    "from pytorch_widedeep.losses import RMSELoss\n",
    "from pytorch_widedeep.initializers import *\n",
    "from pytorch_widedeep.callbacks import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>host_id</th>\n",
       "      <th>description</th>\n",
       "      <th>host_listings_count</th>\n",
       "      <th>host_identity_verified</th>\n",
       "      <th>neighbourhood_cleansed</th>\n",
       "      <th>latitude</th>\n",
       "      <th>longitude</th>\n",
       "      <th>is_location_exact</th>\n",
       "      <th>property_type</th>\n",
       "      <th>...</th>\n",
       "      <th>amenity_wide_entrance</th>\n",
       "      <th>amenity_wide_entrance_for_guests</th>\n",
       "      <th>amenity_wide_entryway</th>\n",
       "      <th>amenity_wide_hallways</th>\n",
       "      <th>amenity_wifi</th>\n",
       "      <th>amenity_window_guards</th>\n",
       "      <th>amenity_wine_cooler</th>\n",
       "      <th>security_deposit</th>\n",
       "      <th>extra_people</th>\n",
       "      <th>yield</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>13913.jpg</td>\n",
       "      <td>54730</td>\n",
       "      <td>My bright double bedroom with a large window h...</td>\n",
       "      <td>4.0</td>\n",
       "      <td>f</td>\n",
       "      <td>Islington</td>\n",
       "      <td>51.56802</td>\n",
       "      <td>-0.11121</td>\n",
       "      <td>t</td>\n",
       "      <td>apartment</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>100.0</td>\n",
       "      <td>15.0</td>\n",
       "      <td>12.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>15400.jpg</td>\n",
       "      <td>60302</td>\n",
       "      <td>Lots of windows and light.  St Luke's Gardens ...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>t</td>\n",
       "      <td>Kensington and Chelsea</td>\n",
       "      <td>51.48796</td>\n",
       "      <td>-0.16898</td>\n",
       "      <td>t</td>\n",
       "      <td>apartment</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>150.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>109.50</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>17402.jpg</td>\n",
       "      <td>67564</td>\n",
       "      <td>Open from June 2018 after a 3-year break, we a...</td>\n",
       "      <td>19.0</td>\n",
       "      <td>t</td>\n",
       "      <td>Westminster</td>\n",
       "      <td>51.52098</td>\n",
       "      <td>-0.14002</td>\n",
       "      <td>t</td>\n",
       "      <td>apartment</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>350.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>149.65</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>24328.jpg</td>\n",
       "      <td>41759</td>\n",
       "      <td>Artist house, bright high ceiling rooms, priva...</td>\n",
       "      <td>2.0</td>\n",
       "      <td>t</td>\n",
       "      <td>Wandsworth</td>\n",
       "      <td>51.47298</td>\n",
       "      <td>-0.16376</td>\n",
       "      <td>t</td>\n",
       "      <td>other</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>250.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>215.60</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>25023.jpg</td>\n",
       "      <td>102813</td>\n",
       "      <td>Large, all comforts, 2-bed flat; first floor; ...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>f</td>\n",
       "      <td>Wandsworth</td>\n",
       "      <td>51.44687</td>\n",
       "      <td>-0.21874</td>\n",
       "      <td>t</td>\n",
       "      <td>apartment</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>250.0</td>\n",
       "      <td>11.0</td>\n",
       "      <td>79.35</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 223 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "          id  host_id                                        description  \\\n",
       "0  13913.jpg    54730  My bright double bedroom with a large window h...   \n",
       "1  15400.jpg    60302  Lots of windows and light.  St Luke's Gardens ...   \n",
       "2  17402.jpg    67564  Open from June 2018 after a 3-year break, we a...   \n",
       "3  24328.jpg    41759  Artist house, bright high ceiling rooms, priva...   \n",
       "4  25023.jpg   102813  Large, all comforts, 2-bed flat; first floor; ...   \n",
       "\n",
       "   host_listings_count host_identity_verified  neighbourhood_cleansed  \\\n",
       "0                  4.0                      f               Islington   \n",
       "1                  1.0                      t  Kensington and Chelsea   \n",
       "2                 19.0                      t             Westminster   \n",
       "3                  2.0                      t              Wandsworth   \n",
       "4                  1.0                      f              Wandsworth   \n",
       "\n",
       "   latitude  longitude is_location_exact property_type  ...  \\\n",
       "0  51.56802   -0.11121                 t     apartment  ...   \n",
       "1  51.48796   -0.16898                 t     apartment  ...   \n",
       "2  51.52098   -0.14002                 t     apartment  ...   \n",
       "3  51.47298   -0.16376                 t         other  ...   \n",
       "4  51.44687   -0.21874                 t     apartment  ...   \n",
       "\n",
       "  amenity_wide_entrance  amenity_wide_entrance_for_guests  \\\n",
       "0                     1                                 0   \n",
       "1                     0                                 0   \n",
       "2                     0                                 0   \n",
       "3                     0                                 0   \n",
       "4                     0                                 0   \n",
       "\n",
       "   amenity_wide_entryway  amenity_wide_hallways  amenity_wifi  \\\n",
       "0                      0                      0             1   \n",
       "1                      0                      0             1   \n",
       "2                      0                      0             1   \n",
       "3                      0                      0             1   \n",
       "4                      0                      0             1   \n",
       "\n",
       "   amenity_window_guards  amenity_wine_cooler security_deposit extra_people  \\\n",
       "0                      0                    0            100.0         15.0   \n",
       "1                      0                    0            150.0          0.0   \n",
       "2                      0                    0            350.0         10.0   \n",
       "3                      0                    0            250.0          0.0   \n",
       "4                      0                    0            250.0         11.0   \n",
       "\n",
       "    yield  \n",
       "0   12.00  \n",
       "1  109.50  \n",
       "2  149.65  \n",
       "3  215.60  \n",
       "4   79.35  \n",
       "\n",
       "[5 rows x 223 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the preprocessed sample of Airbnb listings (tabular columns, text and image file names)\n",
    "df = pd.read_csv(\"../tmp_data/airbnb/airbnb_sample.csv\")\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Regression with the defaults\n",
    "\n",
    "The set up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Columns for the wide (linear) component, including one crossed feature\n",
    "crossed_cols = [(\"property_type\", \"room_type\")]\n",
    "# There are a number of columns that are already binary. Therefore, no need to one hot encode them\n",
    "already_dummies = [c for c in df.columns if \"amenity\" in c] + [\"has_house_rules\"]\n",
    "wide_cols = [\n",
    "    \"is_location_exact\",\n",
    "    \"property_type\",\n",
    "    \"room_type\",\n",
    "    \"host_gender\",\n",
    "    \"instant_bookable\",\n",
    "] + already_dummies\n",
    "\n",
    "# Columns for the deeptabular component. Categorical columns are given as\n",
    "# (column name, embedding dimension) pairs -- see the `TabPreprocessor` docs\n",
    "cat_embed_cols = [(c, 16) for c in df.columns if \"catg\" in c] + [\n",
    "    (\"neighbourhood_cleansed\", 64),\n",
    "    (\"cancellation_policy\", 16),\n",
    "]\n",
    "continuous_cols = [\"latitude\", \"longitude\", \"security_deposit\", \"extra_people\"]\n",
    "\n",
    "# text and image column names (the `id` column holds the image file names)\n",
    "text_col = \"description\"\n",
    "img_col = \"id\"\n",
    "\n",
    "# paths to the pretrained word embeddings and to the images\n",
    "word_vectors_path = \"../tmp_data/glove.6B/glove.6B.100d.txt\"\n",
    "img_path = \"../tmp_data/airbnb/property_picture\"\n",
    "\n",
    "# target\n",
    "target_col = \"yield\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare the data\n",
    "\n",
    "I will focus here on how to prepare the data and run the model. Check notebooks 1 and 2 to see what's going on behind the scenes.\n",
    "\n",
    "Preparing the data is rather simple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regression target (the `yield` column) as a numpy array\n",
    "target = df[target_col].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the wide and crossed columns into the input for the wide (linear) component\n",
    "wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)\n",
    "X_wide = wide_preprocessor.fit_transform(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/javierrodriguezzaurin/Projects/pytorch-widedeep/pytorch_widedeep/preprocessing/tab_preprocessor.py:358: UserWarning: Continuous columns will not be normalised\n",
      "  warnings.warn(\"Continuous columns will not be normalised\")\n"
     ]
    }
   ],
   "source": [
    "# Prepare the categorical (to be embedded) and continuous columns for the deeptabular component.\n",
    "# No scaling is requested here, hence the \"Continuous columns will not be normalised\" warning\n",
    "tab_preprocessor = TabPreprocessor(\n",
    "    cat_embed_cols=cat_embed_cols,\n",
    "    continuous_cols=continuous_cols,\n",
    ")\n",
    "X_tab = tab_preprocessor.fit_transform(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The vocabulary contains 2192 tokens\n",
      "Indexing word vectors...\n",
      "Loaded 400000 word vectors\n",
      "Preparing embeddings matrix...\n",
      "2175 words in the vocabulary had ../tmp_data/glove.6B/glove.6B.100d.txt vectors and appear more than 5 times\n"
     ]
    }
   ],
   "source": [
    "# Build the vocabulary from the `description` column and an embedding matrix\n",
    "# from the pretrained GloVe vectors at `word_vectors_path`\n",
    "text_preprocessor = TextPreprocessor(\n",
    "    word_vectors_path=word_vectors_path, text_col=text_col\n",
    ")\n",
    "X_text = text_preprocessor.fit_transform(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reading Images from ../tmp_data/airbnb/property_picture\n",
      "Resizing\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 1001/1001 [00:01<00:00, 638.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing normalisation metrics\n"
     ]
    }
   ],
   "source": [
    "# Read the images named in the `id` column from `img_path`, resize them and\n",
    "# compute the normalisation metrics (see the log output below)\n",
    "image_processor = ImagePreprocessor(img_col=img_col, img_path=img_path)\n",
    "X_images = image_processor.fit_transform(df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build the model components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# wide component: linear model.\n",
    "# input_dim is the number of distinct encoded values found in X_wide\n",
    "wide = Wide(input_dim=np.unique(X_wide).shape[0], pred_dim=1)\n",
    "\n",
    "# deeptabular component: MLP with two hidden layers (128 and 64 units)\n",
    "tab_mlp = TabMlp(\n",
    "    column_idx=tab_preprocessor.column_idx,\n",
    "    cat_embed_input=tab_preprocessor.cat_embed_input,\n",
    "    continuous_cols=continuous_cols,\n",
    "    mlp_hidden_dims=[128, 64],\n",
    "    mlp_dropout=0.1,\n",
    ")\n",
    "\n",
    "# deeptext component: a stack of 2 LSTMs on top of the pretrained word embeddings\n",
    "basic_rnn = BasicRNN(\n",
    "    vocab_size=len(text_preprocessor.vocab.itos),\n",
    "    embed_matrix=text_preprocessor.embedding_matrix,\n",
    "    n_layers=2,\n",
    "    hidden_dim=64,\n",
    "    rnn_dropout=0.5,\n",
    ")\n",
    "\n",
    "# deepimage component: pretrained ResNet-18\n",
    "# (n_trainable presumably controls how many of its layers are fine-tuned -- see the `Vision` docs)\n",
    "resnet = Vision(pretrained_model_setup=\"resnet18\", n_trainable=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Combine them all with the \"collector\" class `WideDeep`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `head_hidden_dims` adds an FC head (exposed later as `model.deephead`)\n",
    "# in addition to the four components\n",
    "model = WideDeep(\n",
    "    wide=wide,\n",
    "    deeptabular=tab_mlp,\n",
    "    deeptext=basic_rnn,\n",
    "    deepimage=resnet,\n",
    "    head_hidden_dims=[256, 128],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Build the trainer and fit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the objective is set (RMSE); every other Trainer argument keeps its default\n",
    "trainer = Trainer(model, objective=\"rmse\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:19<00:00,  1.28it/s, loss=115]\n",
      "valid: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:04<00:00,  1.62it/s, loss=94.1]\n"
     ]
    }
   ],
   "source": [
    "# A single epoch is enough for this demo; 20% of the rows are held out for validation\n",
    "trainer.fit(\n",
    "    X_wide=X_wide,\n",
    "    X_tab=X_tab,\n",
    "    X_text=X_text,\n",
    "    X_img=X_images,\n",
    "    target=target,\n",
    "    n_epochs=1,\n",
    "    batch_size=32,\n",
    "    val_split=0.2,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both the Text and Image components allow FC-heads of their own (have a look at the documentation).\n",
    "\n",
    "Now let's go \"kaggle crazy\". Let's use different optimizers, initializers and schedulers for different components. Moreover, let's use different learning rates for different parameter groups of the `deeptabular` component."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One optimizer parameter group per `deeptabular` parameter: embedding\n",
    "# parameters are trained with lr=1e-4, all remaining parameters with lr=1e-3\n",
    "deep_params = [\n",
    "    {\"params\": param, \"lr\": 1e-4 if \"embed_layer\" in param_name else 1e-3}\n",
    "    for component_name, component in model.named_children()\n",
    "    if component_name == \"deeptabular\"\n",
    "    for param_name, param in component.named_parameters()\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One optimizer per model component. `deep_opt` receives the per-parameter groups\n",
    "# in `deep_params`; the text, image and head optimizers keep their default learning rate\n",
    "wide_opt = torch.optim.Adam(model.wide.parameters(), lr=0.03)\n",
    "deep_opt = torch.optim.Adam(deep_params)\n",
    "text_opt = torch.optim.AdamW(model.deeptext.parameters())\n",
    "img_opt = torch.optim.AdamW(model.deepimage.parameters())\n",
    "head_opt = torch.optim.Adam(model.deephead.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One learning rate scheduler per component, each wrapping that component's own optimizer.\n",
    "# `gamma` is not set, so each scheduler uses its default decay factor.\n",
    "wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=5)\n",
    "deep_sch = torch.optim.lr_scheduler.MultiStepLR(deep_opt, milestones=[3, 8])\n",
    "text_sch = torch.optim.lr_scheduler.StepLR(text_opt, step_size=5)\n",
    "# The image scheduler must wrap `img_opt` (not `deep_opt`); otherwise the image\n",
    "# learning rate is never decayed and `deep_opt` is driven by two schedulers\n",
    "img_sch = torch.optim.lr_scheduler.MultiStepLR(img_opt, milestones=[3, 8])\n",
    "head_sch = torch.optim.lr_scheduler.StepLR(head_opt, step_size=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# remember, one optimizer per model component; for lr_schedulers and initializers this is not necessary\n",
    "optimizers = {\n",
    "    \"wide\": wide_opt,\n",
    "    \"deeptabular\": deep_opt,\n",
    "    \"deeptext\": text_opt,\n",
    "    \"deepimage\": img_opt,\n",
    "    \"deephead\": head_opt,\n",
    "}\n",
    "schedulers = {\n",
    "    \"wide\": wide_sch,\n",
    "    \"deeptabular\": deep_sch,\n",
    "    \"deeptext\": text_sch,\n",
    "    \"deepimage\": img_sch,\n",
    "    \"deephead\": head_sch,\n",
    "}\n",
    "\n",
    "# Now...we have used pretrained word embeddings, so you do not want to\n",
    "# initialise these embeddings. However you might still want to initialise the\n",
    "# other layers in the DeepText component. No probs, you can do that with the\n",
    "# parameter pattern and your knowledge of regular expressions. Here we are\n",
    "# telling the KaimingNormal initializer to NOT touch the parameters whose\n",
    "# name contains the string word_embed.\n",
    "initializers = {\n",
    "    \"wide\": KaimingNormal,\n",
    "    \"deeptabular\": KaimingNormal,\n",
    "    \"deeptext\": KaimingNormal(pattern=r\"^(?!.*word_embed).*$\"),\n",
    "    \"deepimage\": KaimingNormal,\n",
    "}\n",
    "\n",
    "mean = [0.406, 0.456, 0.485]  # BGR\n",
    "std = [0.225, 0.224, 0.229]  # BGR\n",
    "# NOTE: `ToTensor` (and `EarlyStopping` below) are passed as classes, not instances;\n",
    "# the Trainer is expected to instantiate them -- confirm against the Trainer docs\n",
    "transforms = [ToTensor, Normalize(mean=mean, std=std)]\n",
    "callbacks = [\n",
    "    LRHistory(n_epochs=10),\n",
    "    EarlyStopping,\n",
    "    ModelCheckpoint(filepath=\"model_weights/wd_out\"),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/javierrodriguezzaurin/Projects/pytorch-widedeep/pytorch_widedeep/initializers.py:34: UserWarning: No initializer found for deephead\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# `initializers` has no entry for `deephead`, hence the \"No initializer found for deephead\" warning\n",
    "trainer = Trainer(\n",
    "    model,\n",
    "    objective=\"rmse\",\n",
    "    initializers=initializers,\n",
    "    optimizers=optimizers,\n",
    "    lr_schedulers=schedulers,\n",
    "    callbacks=callbacks,\n",
    "    transforms=transforms,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:19<00:00,  1.25it/s, loss=101]\n",
      "valid: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:04<00:00,  1.62it/s, loss=90.6]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model weights after training corresponds to the those of the final epoch which might not be the best performing weights. Use the 'ModelCheckpoint' Callback to restore the best epoch weights.\n"
     ]
    }
   ],
   "source": [
    "# Same data and settings as the first run; only the Trainer configuration differs\n",
    "trainer.fit(\n",
    "    X_wide=X_wide,\n",
    "    X_tab=X_tab,\n",
    "    X_text=X_text,\n",
    "    X_img=X_images,\n",
    "    target=target,\n",
    "    n_epochs=1,\n",
    "    batch_size=32,\n",
    "    val_split=0.2,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have only run one epoch, but let's check that the `LRHistory` callback records the lr values for each group."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'lr_wide_0': [0.03, 0.03],\n",
       " 'lr_deeptabular_0': [0.0001, 0.0001],\n",
       " 'lr_deeptabular_1': [0.0001, 0.0001],\n",
       " 'lr_deeptabular_2': [0.0001, 0.0001],\n",
       " 'lr_deeptabular_3': [0.0001, 0.0001],\n",
       " 'lr_deeptabular_4': [0.0001, 0.0001],\n",
       " 'lr_deeptabular_5': [0.0001, 0.0001],\n",
       " 'lr_deeptabular_6': [0.0001, 0.0001],\n",
       " 'lr_deeptabular_7': [0.0001, 0.0001],\n",
       " 'lr_deeptabular_8': [0.0001, 0.0001],\n",
       " 'lr_deeptabular_9': [0.001, 0.001],\n",
       " 'lr_deeptabular_10': [0.001, 0.001],\n",
       " 'lr_deeptabular_11': [0.001, 0.001],\n",
       " 'lr_deeptabular_12': [0.001, 0.001],\n",
       " 'lr_deeptext_0': [0.001, 0.001],\n",
       " 'lr_deepimage_0': [0.001, 0.001],\n",
       " 'lr_deephead_0': [0.001, 0.001]}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One entry per optimizer parameter group, keyed as `lr_<component>_<group index>`\n",
    "trainer.lr_history"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "3b99005fd577fa40f3cce433b2b92303885900e634b2b5344c07c59d06c8792d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
